Spatial Ex-ante Framework for Targeting Sowing Dates Using Causal ML and Policy Learning: Eastern India Example

Author

Maxwell Mkondiwa

1 Introduction

In this notebook, I use a causal machine learning estimator, i.e., a multi-arm causal random forest with augmented inverse propensity weighting (Athey et al. 2019), to estimate conditional average treatment effects (CATEs) for agronomic practices. These CATEs are estimated for each individual farm, thereby providing personalized estimates of the potential effectiveness of the practices. I then use doubly robust (debiased) scores in a policy tree optimization (Athey and Wager 2021) to generate optimal recommendations in the form of agronomic practices that maximize potential yield gains.

2 Preliminaries

# Restore the saved workspace; the next line expects it to contain `LDS`.
load("LDS_Public_Workspace.RData")

# Working copy used by every later step of the notebook.
LDSestim <- LDS
# Number of observations in each sowing-date category.
table(LDSestim[["Sowing_Date_Schedule"]])

T5_16Dec T4_15Dec T3_30Nov T2_20Nov T1_10Nov 
     665     1696     3167     1704      416 

2.1 Graphics

# Bar graphs showing how many farmers fall into each practice category

library(tidyverse)
library(ggplot2)

# Horizontal bar chart of observation counts for one categorical column.
#   dat: data frame that holds the column.
#   var: unquoted column name (tidy evaluation via {{ }}).
# Rows with a missing `var` are dropped, and the factor levels are ordered
# by decreasing frequency before plotting.
bar_chart <- function(dat, var) {
  counted <- dat |>
    drop_na({{ var }}) |>
    mutate({{ var }} := fct_infreq(factor({{ var }})))

  ggplot(counted) +
    geom_bar(aes(y = {{ var }}), fill = "dodgerblue4") +
    theme_minimal(base_size = 16)
}

sow_plot <- bar_chart(LDSestim, Sowing_Date_Schedule) +
  labs(y = "Sowing dates")

sow_plot

library(ggpubr)
library(tidyverse)

# Sowing dates: mean wheat yield with error bars for each sowing-date option.
# Rows without a sowing-date category are removed first; the axes are flipped
# so the options read top-to-bottom.
SowingDate_Options_Errorplot <- LDSestim |>
  drop_na(Sowing_Date_Schedule) |>
  ggerrorplot(
    x = "Sowing_Date_Schedule",
    y = "L.tonPerHectare",
    add = "mean",
    error.plot = "errorbar",
    color = "steelblue",
    ggtheme = theme_bw()
  ) +
  labs(x = "Sowing date options", y = "Wheat yield (t/ha)") +
  theme_minimal(base_size = 16) +
  coord_flip()

SowingDate_Options_Errorplot

2.2 Descriptives

# NOTE(review): this working directory is specific to one machine. It is only
# entered when it exists, so the notebook no longer aborts with
# "cannot change working directory" elsewhere. Prefer project-relative paths.
project_dir <- "D:/OneDrive/CIMMYT/Papers/IO5.3.1.CropResponseModels/WheatResponse/Policytree_Models"
if (dir.exists(project_dir)) {
  setwd(project_dir)
}

library(fBasics)

# Descriptive statistics for yield, inputs and management indicators.
summ_stats <- fBasics::basicStats(
  LDSestim[, c(
    "L.tonPerHectare", "G.q5305_irrigTimes", "Nperha", "P2O5perha",
    "Weedmanaged", "variety_type_NMWV", "Sowing_Date_Early"
  )]
)

# basicStats() puts statistics in rows; transpose so each variable is a row.
summ_stats <- as.data.frame(t(summ_stats))

# Keep the headline statistics and rename the quartile columns for convenience
summ_stats <- summ_stats[c(
  "Mean", "Stdev", "Minimum", "1. Quartile", "Median", "3. Quartile", "Maximum"
)] %>%
  rename("Lower quartile" = "1. Quartile", "Upper quartile" = "3. Quartile")

summ_stats
                         Mean     Stdev Minimum Lower quartile    Median
L.tonPerHectare      2.990635  0.854990     0.2         2.4000   3.00000
G.q5305_irrigTimes   2.286183  0.767177     1.0         2.0000   2.00000
Nperha             130.220556 37.038394     0.0       105.1852 132.54321
P2O5perha           59.038865 19.629879     0.0        45.4321  59.77908
Weedmanaged          0.758990  0.427724     0.0         1.0000   1.00000
variety_type_NMWV    0.530097  0.499126     0.0         0.0000   1.00000
Sowing_Date_Early    0.277197  0.447644     0.0         0.0000   0.00000
                   Upper quartile  Maximum
L.tonPerHectare           3.43000   6.5000
G.q5305_irrigTimes        3.00000   5.0000
Nperha                  156.00000 298.4691
P2O5perha                72.69136 212.9630
Weedmanaged               1.00000   1.0000
variety_type_NMWV         1.00000   1.0000
Sowing_Date_Early         1.00000   1.0000

2.3 Sowing dates at district level

2.3.0.1 Percent of farmers

# Cross-tabulate district (rows) against sowing-date category (columns).
table(LDSestim$A.q103_district, LDSestim$Sowing_Date_Schedule)
                
                 T5_16Dec T4_15Dec T3_30Nov T2_20Nov T1_10Nov
  Arah                 37       89       55       22        5
  Araria                0       11       59       39        8
  Arwal                84       66       15       16        0
  Aurangabad           41      112       34        7        0
  Balia                 0       26       92       83       15
  Banka                73       71       23        8        1
  Begusarai             0        1       20       90      102
  Bhagalpur            41       41       68       46       11
  Buxar                24       64       84       31        4
  Chandauli            26      141       30       10        1
  Deoria                0        8       65       97       39
  EastChamparan        22       48       96       38        1
  Gaya                 25      103       43       20        2
  Gazipur               3       38      140       27        2
  Gopalganj             6       14      104       75       11
  Gorakhpur             0        2       72      122       14
  Jehanabad            47       82       23       15       22
  Kaimur               55      115       31        3        0
  Katihar               7        3       69       31        5
  Khagaria              4       15       51       71       64
  Kushinagar            7       11      115       58        4
  Lakhisarai           37       48       47       42       21
  Madhepura             3       15      123       21        7
  Maharajganj           2       10      122       68        8
  Mau                   3       28      116       52        5
  Munger               20       61       90       27       12
  Muzaffarpur           0       14      119       65        6
  Nalanda               9       33      123       26        5
  Patna                16       66       87       31        3
  Purnia                0        0        5        2        1
  Rohtas               38      106       45        5        2
  Saharsa              22       70       56       14        1
  Samastipur            0        2      119       73        8
  Saran                 4       21      131       51        2
  Sheohar               3       53      104       48        1
  Siddharthnagar        0        3      108       67       15
  Siwan                 0       19       98       81        0
  Supaul                3       47      109       44        3
  Vaishali              0        0      153       42        1
  WestChamparan         3       39      123       36        4
library(modelsummary)

# District-by-sowing-date cross tabulation, returned as a plain data frame.
Sowpercent <- datasummary_crosstab(
  A.q103_district ~ Sowing_Date_Schedule,
  data = LDSestim,
  output = "data.frame"
)

library(reactable)
library(htmltools)
library(fontawesome)

# Button whose onclick handler exports the table with element id "Sowpercent".
sowpercent_download_button <- tags$button(
  tagList(fontawesome::fa("download"), "Download as CSV"),
  onclick = "Reactable.downloadDataCSV('Sowpercent', 'Sowpercent.csv')"
)

# Searchable table; its elementId must match the id used by the button above.
sowpercent_table <- reactable(
  Sowpercent,
  searchable = TRUE,
  defaultPageSize = 38,
  elementId = "Sowpercent"
)

htmltools::browsable(tagList(sowpercent_download_button, sowpercent_table))

2.3.0.2 Area share by sowing dates

library(data.table)

LDSestim_DT <- data.table(LDSestim)

# Total sampled acres per (sowing date, district) cell. Naming the aggregate
# inside .() replaces the assignment inside `j` plus the rename of the
# auto-generated `V1` column.
SampleAcres_dist_sowdate <- LDSestim_DT[
  , .(SampleAcres_by_sowdate = sum(C.q306_cropLarestAreaAcre, na.rm = TRUE)),
  by = c("Sowing_Date_Schedule", "A.q103_district")
]

library(dplyr)

# Total sampled acres per district (denominator of the share).
SampleAcres_dist <- LDSestim_DT[
  , .(SampleAcres = sum(C.q306_cropLarestAreaAcre, na.rm = TRUE)),
  by = c("A.q103_district")
]

# Attach the district total to every (district, sowing date) row.
SampleAcres_dist_sowdate_merged <- merge(
  SampleAcres_dist_sowdate,
  SampleAcres_dist,
  by = "A.q103_district"
)

# Share of the district's sampled acres sown under each sowing-date category.
SampleAcres_dist_sowdate_merged$Share_acres_by_sowdate <-
  SampleAcres_dist_sowdate_merged$SampleAcres_by_sowdate /
    SampleAcres_dist_sowdate_merged$SampleAcres

# District rows by sowing-date columns, reported as a data frame.
Share_acres_by_sowdate_group <- datasummary(
  A.q103_district * Share_acres_by_sowdate ~ Sowing_Date_Schedule * (Mean),
  data = SampleAcres_dist_sowdate_merged,
  output = "data.frame"
)

library(reactable)
library(htmltools)
library(fontawesome)

# Button whose onclick handler exports the table with element id
# "Share_acres_by_sowdate_group".
share_acres_download_button <- tags$button(
  tagList(fontawesome::fa("download"), "Download as CSV"),
  onclick = "Reactable.downloadDataCSV('Share_acres_by_sowdate_group', 'Share_acres_by_sowdate_group.csv')"
)

# Searchable table; its elementId must match the id used by the button above.
share_acres_table <- reactable(
  Share_acres_by_sowdate_group,
  searchable = TRUE,
  defaultPageSize = 38,
  elementId = "Share_acres_by_sowdate_group"
)

htmltools::browsable(tagList(share_acres_download_button, share_acres_table))

3 Causal Random Forest Model

library(grf)
library(policytree)

# Covariates of the forests, defined once so the estimation sample and the
# covariate matrix cannot drift apart.
covariates_sowing <- c(
  "I.q5505_weedSeverity_num", "I.q5509_diseaseSeverity_num",
  "I.q5506_insectSeverity_num", "Nperha", "P2O5perha", "variety_type_NMWV",
  "G.q5305_irrigTimes", "A.q111_fGenderdum", "Weedmanaged", "temp", "precip",
  "wc2.1_30s_elev", "M.q708_marketDistance", "nitrogen_0.5cm", "sand_0.5cm",
  "soc_5.15cm", "O.largestPlotGPS.Latitude", "O.largestPlotGPS.Longitude"
)

# Estimation sample: treatment, outcome, covariates and drought severity, in
# the same column order as before.
# NOTE(review): drought severity filters rows through drop_na() below but is
# not among the covariates -- confirm this is intentional.
sample_columns_sowing <- c(
  "Sowing_Date_Schedule", "L.tonPerHectare",
  append(covariates_sowing, "I.q5502_droughtSeverity_num", after = 3)
)
LDSestim_sow <- subset(LDSestim, select = sample_columns_sowing)

library(tidyr)
# Complete cases only.
LDSestim_sow <- LDSestim_sow %>% drop_na()

# Outcome: wheat yield (column L.tonPerHectare).
Y_cf_sowing <- as.vector(LDSestim_sow$L.tonPerHectare)

## Causal random forest -----------------

X_cf_sowing <- subset(LDSestim_sow, select = covariates_sowing)

# Treatment arm as a factor.
W_cf_sowing <- as.factor(LDSestim_sow$Sowing_Date_Schedule)

# Treatment model: probability of each sowing-date arm given the covariates.
W.multi_sowing.forest <- probability_forest(
  X_cf_sowing, W_cf_sowing,
  equalize.cluster.weights = FALSE,
  seed = 2
)
# Only the point predictions are kept, so variance estimation is not requested.
W.hat.multi.all_sowing <- predict(W.multi_sowing.forest)$predictions

# Outcome model: yield given the covariates.
Y.multi_sowing.forest <- regression_forest(
  X_cf_sowing, Y_cf_sowing,
  equalize.cluster.weights = FALSE,
  seed = 2
)

print(Y.multi_sowing.forest)
GRF forest object of type regression_forest 
Number of trees: 2000 
Number of training samples: 7562 
Variable importance: 
    1     2     3     4     5     6     7     8     9    10    11    12    13 
0.000 0.000 0.001 0.021 0.015 0.651 0.193 0.000 0.064 0.001 0.001 0.006 0.001 
   14    15    16    17    18 
0.001 0.004 0.002 0.023 0.015 
# Variable importance of the outcome (regression) forest.
varimp.multi_sowing <- variable_importance(Y.multi_sowing.forest)
# Outcome predictions on the training data (no newdata supplied). Only the
# point predictions are kept, so variance estimation is not requested.
Y.hat.multi.all_sowing <- predict(Y.multi_sowing.forest)$predictions

# Multi-arm causal forest using the pre-computed nuisance estimates.
multi_sowing.forest <- multi_arm_causal_forest(
  X = X_cf_sowing,
  Y = Y_cf_sowing,
  W = W_cf_sowing,
  W.hat = W.hat.multi.all_sowing,
  Y.hat = Y.hat.multi.all_sowing,
  seed = 2
)

# The causal forest's variable importance is computed once, right before it is
# plotted further down; the duplicate call that used to sit here was removed.

# Average treatment effects (AIPW) of each sowing date versus the baseline arm.
multi_sowing_ate <- average_treatment_effect(multi_sowing.forest, method = "AIPW")
multi_sowing_ate
                     estimate    std.err            contrast outcome
T4_15Dec - T5_16Dec 0.2358984 0.02577961 T4_15Dec - T5_16Dec     Y.1
T3_30Nov - T5_16Dec 0.4245938 0.02284938 T3_30Nov - T5_16Dec     Y.1
T2_20Nov - T5_16Dec 0.5638900 0.02572197 T2_20Nov - T5_16Dec     Y.1
T1_10Nov - T5_16Dec 0.6981874 0.03905789 T1_10Nov - T5_16Dec     Y.1
varimp.multi_sowing_cf <- variable_importance(multi_sowing.forest)
vars_sowing <- c(
  "I.q5505_weedSeverity_num", "I.q5509_diseaseSeverity_num",
  "I.q5506_insectSeverity_num", "Nperha", "P2O5perha", "variety_type_NMWV",
  "G.q5305_irrigTimes", "A.q111_fGenderdum", "Weedmanaged", "temp", "precip",
  "wc2.1_30s_elev", "M.q708_marketDistance", "nitrogen_0.5cm", "sand_0.5cm",
  "soc_5.15cm", "O.largestPlotGPS.Latitude", "O.largestPlotGPS.Longitude"
)

## variable importance plot ----------------------------------------------------
# Build the frame with typed columns. cbind() of the numeric importances with
# the character names coerced everything to character, so the formatC()-based
# two-decimal rounding was applied to strings rather than numbers.
varimpvars_sowing <- data.frame(
  Variableimportance_sowing = round(as.numeric(varimp.multi_sowing_cf), 2),
  vars_sowing = vars_sowing
)

# geom_point() instead of geom_jitter(): jitter adds random offsets, which
# shifts the plotted importances and differs between renders.
varimpplotRF_sowing <- ggplot(
  varimpvars_sowing,
  aes(
    x = reorder(vars_sowing, Variableimportance_sowing),
    y = Variableimportance_sowing
  )
) +
  geom_point(color = "steelblue") +
  coord_flip() +
  labs(x = "Variables", y = "Variable importance")

# Sets the session-wide default ggplot theme; the previous one is kept in
# `previous_theme` so it can be restored.
previous_theme <- theme_set(theme_bw(base_size = 16))
varimpplotRF_sowing

# Policy tree --------------------------------------
# Reward scores per farm and arm, taken from the fitted causal forest.
DR.scores_sowing <- double_robust_scores(multi_sowing.forest)

# Depth-2 policy tree.
tr_sowing <- policy_tree(X = X_cf_sowing, Gamma = DR.scores_sowing, depth = 2)
plot(tr_sowing)

# Depth-3 tree fitted with the hybrid search.
tr_sowing3 <- hybrid_policy_tree(X = X_cf_sowing, Gamma = DR.scores_sowing, depth = 3)
print(tr_sowing3)
policy_tree object 
Tree depth:  3 
Actions:  1: T5_16Dec 2: T4_15Dec 3: T3_30Nov 4: T2_20Nov 5: T1_10Nov 
Variable splits: 
(1) split_variable: O.largestPlotGPS.Latitude  split_value: 25.84 
  (2) split_variable: P2O5perha  split_value: 62.4691 
    (4) split_variable: O.largestPlotGPS.Latitude  split_value: 25.42 
      (8) * action: 5 
      (9) * action: 4 
    (5) split_variable: soc_5.15cm  split_value: 9.1 
      (10) * action: 4 
      (11) * action: 5 
  (3) split_variable: O.largestPlotGPS.Longitude  split_value: 83.47 
    (6) split_variable: precip  split_value: 710.4 
      (12) * action: 5 
      (13) * action: 4 
    (7) split_variable: temp  split_value: 25.6 
      (14) * action: 2 
      (15) * action: 5 
plot(tr_sowing3)

# Depth-4 tree fitted with the hybrid search.
tr_sowing4 <- hybrid_policy_tree(X = X_cf_sowing, Gamma = DR.scores_sowing, depth = 4)
print(tr_sowing4)
policy_tree object 
Tree depth:  4 
Actions:  1: T5_16Dec 2: T4_15Dec 3: T3_30Nov 4: T2_20Nov 5: T1_10Nov 
Variable splits: 
(1) split_variable: O.largestPlotGPS.Latitude  split_value: 25.84 
  (2) split_variable: P2O5perha  split_value: 62.4691 
    (4) split_variable: soc_5.15cm  split_value: 16.2 
      (8) split_variable: O.largestPlotGPS.Latitude  split_value: 25.42 
        (16) * action: 5 
        (17) * action: 4 
      (9) split_variable: I.q5509_diseaseSeverity_num  split_value: 2 
        (18) * action: 4 
        (19) * action: 5 
    (5) split_variable: I.q5505_weedSeverity_num  split_value: 3 
      (10) split_variable: soc_5.15cm  split_value: 9.1 
        (20) * action: 4 
        (21) * action: 5 
      (11) split_variable: precip  split_value: 785.4 
        (22) * action: 4 
        (23) * action: 5 
  (3) split_variable: O.largestPlotGPS.Longitude  split_value: 83.47 
    (6) split_variable: P2O5perha  split_value: 56.7901 
      (12) split_variable: precip  split_value: 710.4 
        (24) * action: 5 
        (25) * action: 4 
      (13) split_variable: temp  split_value: 26.05 
        (26) * action: 5 
        (27) * action: 4 
    (7) split_variable: P2O5perha  split_value: 74.963 
      (14) split_variable: temp  split_value: 25.6 
        (28) * action: 2 
        (29) * action: 5 
      (15) split_variable: temp  split_value: 26.1333 
        (30) * action: 5 
        (31) * action: 4 
plot(tr_sowing4)

# Copy of the estimation sample that collects the action each tree assigns.
tr_assignment_sowing <- LDSestim_sow

# Action index assigned by the depth-2 tree, and how often each one occurs.
tr_assignment_sowing[["depth2"]] <- predict(tr_sowing, X_cf_sowing)
table(tr_assignment_sowing[["depth2"]])

   4    5 
2031 5531 
# Action index assigned by the depth-3 tree, and how often each one occurs.
tr_assignment_sowing[["depth3"]] <- predict(tr_sowing3, X_cf_sowing)
table(tr_assignment_sowing[["depth3"]])

   2    4    5 
 183 1252 6127 
# Action index assigned by the depth-4 tree, and how often each one occurs.
tr_assignment_sowing[["depth4"]] <- predict(tr_sowing4, X_cf_sowing)
table(tr_assignment_sowing[["depth4"]])

   2    4    5 
 159 1565 5838 

4 Policy learning algorithm for treatment assignment

# SpatialPointsDataFrame() and CRS() come from sp. It used to be attached only
# indirectly through rgdal, which has been retired from CRAN.
library(sp)

# Label the depth-2 action indices with their sowing-date names. The order
# follows the "Actions:" line printed for the policy trees (1 = T5_16Dec ...
# 5 = T1_10Nov).
sowing_action_labels <- c("T5_16Dec", "T4_15Dec", "T3_30Nov", "T2_20Nov", "T1_10Nov")
tr_assignment_sowing$depth2_cat <- sowing_action_labels[tr_assignment_sowing$depth2]

# Point layer built from the plot GPS coordinates (longitude first, then
# latitude), declared as WGS84 longitude/latitude.
tr_assignment_sowingsp <- SpatialPointsDataFrame(
  cbind(
    tr_assignment_sowing$O.largestPlotGPS.Longitude,
    tr_assignment_sowing$O.largestPlotGPS.Latitude
  ),
  data = tr_assignment_sowing,
  proj4string = CRS("+proj=longlat +datum=WGS84")
)

library(mapview)
mapviewOptions(fgb = FALSE)
# Interactive map coloured by the recommended sowing date.
tr_assignment_sowingspmapview <- mapview(
  tr_assignment_sowingsp,
  zcol = "depth2_cat",
  layer.name = "Recommended sowing dates"
)
tr_assignment_sowingspmapview

5 Distributional analysis

library(ggridges)
library(dplyr)

# Per-farm treatment-effect predictions with variance estimates.
# NOTE(review): `target.sample` does not look like an argument of predict()
# for multi-arm causal forests and is probably ignored -- confirm.
tau.multi_sowing.forest <- predict(
  multi_sowing.forest,
  target.sample = "all",
  estimate.variance = TRUE
)

tau.multi_sowing.forest <- as.data.frame(tau.multi_sowing.forest)

# Estimation sample side by side with the predictions.
tau.multi_sowing.forest_X <- data.frame(LDSestim_sow, tau.multi_sowing.forest)

# Ridges -------------------
# The first four columns are the four contrast predictions renamed below.
tau.multi_sowing.forest_pred <- tau.multi_sowing.forest[, 1:4]

library(dplyr)
library(reshape2)
# One rename call for all four contrasts (new name = old name).
tau.multi_sowing.forest_pred <- dplyr::rename(
  tau.multi_sowing.forest_pred,
  "T4_15Dec - T5_16Dec" = "predictions.T4_15Dec...T5_16Dec.Y.1",
  "T3_30Nov-T5_16Dec" = "predictions.T3_30Nov...T5_16Dec.Y.1",
  "T2_20Nov-T5_16Dec" = "predictions.T2_20Nov...T5_16Dec.Y.1",
  "T1_10Nov-T5_16Dec" = "predictions.T1_10Nov...T5_16Dec.Y.1"
)

# Long format: one row per (farm, contrast) with columns `variable`/`value`.
tau.multi_sowing.forest_pred_long <- reshape2::melt(tau.multi_sowing.forest_pred[, 1:4])

# after_stat() replaces the deprecated stat() and matches the usage in the
# transition-matrix plot later in this notebook.
ggplot(
  tau.multi_sowing.forest_pred_long,
  aes(x = value, y = variable, fill = factor(after_stat(quantile)))
) +
  stat_density_ridges(
    geom = "density_ridges_gradient", calc_ecdf = TRUE,
    quantiles = 4, quantile_lines = TRUE
  ) +
  scale_fill_viridis_d(name = "Quartiles") +
  theme_bw(base_size = 16) +
  labs(x = "Wheat yield gain(t/ha)", y = "Sowing date options")

6 Transition matrix of the policy change

# Label the depth-2 action indices with their sowing-date names. The order
# follows the "Actions:" line printed for the policy trees (1 = T5_16Dec ...
# 5 = T1_10Nov).
sowing_action_labels <- c("T5_16Dec", "T4_15Dec", "T3_30Nov", "T2_20Nov", "T1_10Nov")
tr_assignment_sowing$depth2_cat <- sowing_action_labels[tr_assignment_sowing$depth2]

library(ggalluvial)
library(data.table)
tr_assignment_sowingDT <- data.table(tr_assignment_sowing)

# Number of farms per (observed sowing date, recommended sowing date) pair.
# Naming the count inside .() replaces the assignment inside `j` plus the
# rename of the auto-generated `V1` column.
TransitionMatrix_sowing <- tr_assignment_sowingDT[
  , .(Freq = .N),
  by = c("Sowing_Date_Schedule", "depth2_cat")
]
library(dplyr)

library(scales)
# Alluvial plot: left axis is the observed schedule, right axis is the
# recommendation; flows are coloured by the recommendation.
transitionmatrixplot_sowing <- ggplot(
  data = TransitionMatrix_sowing,
  aes(axis1 = Sowing_Date_Schedule, axis2 = depth2_cat, y = Freq)
) +
  geom_alluvium(aes(fill = depth2_cat)) +
  geom_stratum() +
  # geom_text(stat="stratum", aes(label=after_stat(stratum),nudge_y =5))+
  geom_text(
    stat = "stratum",
    aes(label = paste(after_stat(stratum), percent(after_stat(prop))))
  ) +
  scale_x_discrete(
    limits = c("Sowing_Date_Schedule", "depth2_cat"),
    expand = c(0.15, 0.05)
  ) +
  scale_fill_viridis_d() +
  theme_void(base_size = 20) +
  theme(legend.position = "none")

transitionmatrixplot_sowing